import random
import argparse
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from core.MDP import MDP
from utils.random_board import generate_random_board, generate_one_goal_board
from utils.utils import *
from models.Random_model import RandomMDP
from models.Grid_world import Grid_world

# Tiny numerical floor used by `run`: it is added inside np.log so that a zero
# value gap does not produce log(0), and it is also passed as `epsilon` to the
# reference policy-iteration solve that produces V*.
EPSILON = 1e-21

def run(model: "RandomMDP",
        max_iter=10000,
        param_list=None,
        exp_mode="multi-run",
        metric="rho",
        noise=None,
        seed=21,
        clip=1e-13,
        path="./Figures/Exp_Exp.pdf"):
    '''
    Run one or more algorithms on ``model`` and plot their log value error.

    V* is first computed with ``solve_mdp(mode={"alg": "policy_iteration"})``.
    Each configured algorithm is then re-initialised with
    ``init_policy_and_V(random_init=True, seed=seed)`` and solved. For every
    returned iterate V_k, the gap V* - V_k is reduced to a scalar according
    to ``metric`` and its log is plotted. The figure is saved to ``path``.

    Args:
        model: object exposing ``solve_mdp(...)`` and an ``mdp`` attribute
            with ``V``, ``S_size``, ``compute_delta()`` and
            ``init_policy_and_V(...)``.
        max_iter: initial iteration budget passed to ``solve_mdp``. After each
            run, the shared horizon is capped at the number of iterates whose
            log error exceeds ``clip``. Later algorithms are solved only up to
            that horizon, and every curve is plotted up to it.
        param_list:
            [[mode1, step1], [mode2, step2] ... ]. Each ``mode`` dict is
            forwarded to ``solve_mdp``. An optional ``"label"`` sets the
            legend text; ``"phi"`` is required for "local-rate". Defaults to
            ``[[{"alg": "Softmax_NPG"}, .1]]``.
        exp_mode: one of
            - "multi-run": overlaid curves with a legend;
            - "single-run": one curve, titled with its label;
            - "local-rate": the curve plus the line with slope
              log(phi(-step * delta) / phi(0));
            - "policy-converge": mean L1 distance of every iterate's policy
              to the last one, plus ``model.visualize_prob_policy``.
            Every mode except "multi-run" accepts exactly one algorithm and no
            noise.
        metric: one of
            - "rho": mean gap over states;
            - "infty": max gap;
            - "random-rho": gap weighted by a single seeded random probability
              vector, shared by every algorithm in ``param_list``.
        noise: forwarded to ``solve_mdp``.
        seed: seed for initialisation, ``solve_mdp`` and the random-rho weights.
        clip: threshold compared against the *log* error when capping the
            horizon. NOTE(review): the default 1e-13 looks like a raw error
            threshold rather than a log one; every visible caller passes a
            negative value such as -16.
        path: output file for the figure.

    Raises:
        ValueError: if ``metric`` is not one of the supported values.
    '''
    if param_list is None:
        param_list = [[{"alg": "Softmax_NPG"}, .1]]
    assert exp_mode in ["multi-run", "local-rate", "policy-converge", "single-run"]
    if exp_mode != "multi-run":
        assert len(param_list) == 1, "Only support one algorithm."
        assert noise is None, "Please do not add noise."
    if metric not in ("rho", "infty", "random-rho"):
        raise ValueError(
            f"Unknown metric {metric!r}; expected 'rho', 'infty' or 'random-rho'.")

    model.solve_mdp(mode={"alg": "policy_iteration"},
                    epsilon=EPSILON)
    V_star = model.mdp.V.copy()
    delta = model.mdp.compute_delta()

    # Draw the random state weighting once, so every algorithm is measured
    # under the same metric, and seed it so the figure is reproducible.
    weights = None
    if metric == "random-rho":
        weights = np.random.default_rng(seed).uniform(0, 1, size=(model.mdp.S_size,))
        weights = weights / weights.sum()

    log_diffs = []          # one log-error curve per entry of param_list, in order
    policy_list = None
    horizon = max_iter
    for (mode, step_size) in tqdm(param_list):
        model.mdp.init_policy_and_V(random_init=True, seed=seed)
        return_dict = model.solve_mdp(mode=mode,
                                      max_iter=horizon,
                                      step_size=step_size,
                                      verbose=True,
                                      need_return=True,
                                      noise=noise,
                                      seed=seed)
        if exp_mode == "policy-converge":
            policy_list = return_dict["policy_list"]
        log_diff_list = _log_value_errors(
            V_star, return_dict["V_list"], metric, weights, EPSILON)
        log_diffs.append(log_diff_list)
        # Only iterates still above `clip` (i.e. not yet at machine precision)
        # are worth plotting; all curves share the shortest such horizon.
        horizon = min(horizon, int(np.count_nonzero(log_diff_list > clip)))

    if exp_mode in ("multi-run", "single-run"):
        _plot_log_errors(param_list, log_diffs, horizon, path,
                         single=(exp_mode == "single-run"))
    elif exp_mode == "local-rate":
        mode, step_size = param_list[0]
        _plot_local_rate(mode, step_size, log_diffs[0], horizon, delta, path)
    else:  # "policy-converge"
        _plot_policy_convergence(model, policy_list, path)


def _log_value_errors(V_star, V_list, metric, weights, eps):
    '''Return log(gap + eps) for each iterate V in ``V_list``.

    The gap V_star - V is reduced by ``metric``: its mean ("rho"), its max
    ("infty"), or its dot product with ``weights`` ("random-rho").
    A negative gap yields NaN, as np.log does.
    '''
    if metric == "rho":
        gaps = [(V_star - V).mean() for V in V_list]
    elif metric == "infty":
        gaps = [(V_star - V).max() for V in V_list]
    elif metric == "random-rho":
        gaps = [np.dot(V_star - V, weights) for V in V_list]
    else:
        raise ValueError(f"Unknown metric {metric!r}.")
    return np.log(np.asarray(gaps, dtype=float) + eps)


def _plot_log_errors(param_list, log_diffs, n_points, path, single=False):
    '''Plot the first ``n_points`` of each curve and save the figure to ``path``.

    With ``single`` the legend is omitted and the (last) label is used as the
    title; otherwise a legend is drawn. The figure is closed after saving.
    '''
    fig, ax = plt.subplots(figsize=(5, 4))
    label = None
    for (mode, _step), diff_list in zip(param_list, log_diffs):
        label = mode.get("label", mode["alg"])
        ax.plot(np.arange(n_points), diff_list[:n_points], '-', label=str(label))

    ax.set_xlabel("iters")
    ax.set_ylabel("log value error")
    if not single:
        ax.legend()
    ax.grid(True)
    ax.grid(alpha=0.3)
    ax.xaxis.set_major_locator(plt.MaxNLocator(5))
    if single:
        ax.set_title(str(label))
    fig.savefig(path)
    plt.close(fig)


def _plot_local_rate(mode, step_size, diff_list, n_points, delta, path):
    '''Plot one log-error curve against its theoretical local rate.

    The theory line has slope log(phi(-step_size * delta) / phi(0)). It is
    anchored to the empirical curve at 99% of the plotted horizon.
    '''
    assert mode.get("phi", None) is not None, "Please provide the phi function."
    phi = mode["phi"]

    fig, ax = plt.subplots(figsize=(5, 4))
    ax.plot(np.arange(n_points), diff_list[:n_points], '-', label="Convergence curve")

    theory_rate = np.log(phi(-step_size * delta) / phi(0))
    # Anchor in the tail, where the local (asymptotic) rate should apply.
    # n * 99 // 100 stays < n for n >= 1, unlike the coarse n // 100 * 99.
    anchor = n_points * 99 // 100
    intercept = diff_list[anchor] - theory_rate * anchor
    theory_line = np.arange(n_points) * theory_rate + intercept
    ax.plot(theory_line, '--', color="red", label="Local convergence rate")

    ax.set_xlabel("iters")
    ax.set_ylabel("log value error")
    ax.legend()
    ax.grid(True)
    ax.grid(alpha=0.3)
    ax.xaxis.set_major_locator(plt.MaxNLocator(5))
    fig.savefig(path)
    plt.close(fig)


def _plot_policy_convergence(model, policy_list, path):
    '''Plot the state-averaged L1 distance of each iterate's policy to the
    final one (top) above ``model.visualize_prob_policy`` (bottom).

    NOTE(review): assumes ``policy[s]`` supports subtraction and
    ``np.linalg.norm(..., ord=1)`` (e.g. a per-state probability vector);
    confirm against ``solve_mdp``.
    '''
    final_policy = policy_list[-1]
    per_state = [
        [np.linalg.norm(final_policy[s] - policy[s], ord=1) for policy in policy_list]
        for s in tqdm(range(model.mdp.S_size))
    ]
    mean_divergence = np.mean(np.array(per_state), axis=0)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(4, 8))
    ax1.plot(mean_divergence, '-', label=r"$|| \pi^k - \pi^\mathrm{last} ||_1$")
    ax1.legend()
    ax1.grid(True)
    ax1.grid(alpha=0.3)

    model.visualize_prob_policy(ax2, verbose=True)
    fig.savefig(path, pad_inches=0.2, bbox_inches="tight")
    plt.close(fig)
        
        

if __name__ == "__main__":
    from pathlib import Path

    # Experiment table: name -> zero-argument callable returning the keyword
    # arguments for `run`. The callables are only evaluated for the selected
    # experiment(s), so `model`, `grid_world` and `seed` (defined below) are
    # looked up at call time and phi factories are only built when needed.
    # Insertion order is the order used by `--type all`.
    experiments = {
        "Local-1": lambda: dict(
            model=model,
            max_iter=10000,
            param_list=[
                [{"alg": "phi", "label": "softmax NPG", "phi": phi_exp_inv_factory(1, 1)}, 1],
            ],
            exp_mode="local-rate",
            path="./Figures/Local_Exp1.pdf",
            clip=-16,
        ),
        "Local-2": lambda: dict(
            model=model,
            max_iter=20000,
            param_list=[
                [{"alg": "phi", "label": "Sigmoid", "phi": phi_sigmoid_factory()}, 1],
            ],
            exp_mode="local-rate",
            path="./Figures/Local_Exp2.pdf",
            clip=-16,
        ),
        "Local-3": lambda: dict(
            model=model,
            max_iter=40000,
            param_list=[
                [{"alg": "phi", "label": "tan", "phi": phi_tan_factory()}, 0.1],
            ],
            exp_mode="local-rate",
            path="./Figures/Local_Exp3.pdf",
            clip=-16,
        ),
        "Policy-1": lambda: dict(
            model=grid_world,
            max_iter=10000,
            param_list=[
                [{"alg": "phi", "label": "softmax NPG", "phi": phi_exp_inv_factory(1, 1)}, 1],
            ],
            exp_mode="policy-converge",
            path="./Figures/Policy_Exp1.pdf",
            clip=-16,
        ),
        "Policy-2": lambda: dict(
            model=grid_world,
            max_iter=10000,
            param_list=[
                [{"alg": "phi", "label": "Sigmoid", "phi": phi_sigmoid_factory()}, 1],
            ],
            exp_mode="policy-converge",
            path="./Figures/Policy_Exp2.pdf",
            clip=-16,
        ),
        "Policy-3": lambda: dict(
            model=grid_world,
            max_iter=10000,
            param_list=[
                [{"alg": "phi", "label": "tan", "phi": phi_tan_factory()}, 0.1],
            ],
            exp_mode="policy-converge",
            path="./Figures/Policy_Exp3.pdf",
            clip=-16,
        ),
        "Exp": lambda: dict(
            model=model,
            max_iter=10000,
            param_list=[
                [{"alg": "phi", "label": "softmax NPG $(p=1, q=1)$", "phi": phi_exp_inv_factory(1, 1)}, 1],
                [{"alg": "phi", "label": r"$p=3, q=5, \delta=0.01$", "phi": phi_exp_refined_factory(3, 5, 0.01)}, 1],
                # [{"alg": "phi", "label": "Exp(3,5,0.001)", "phi": phi_exp_refined_factory(3, 5, 0.001)}, 1],
                [{"alg": "phi", "label": r"$p=5, q=7, \delta=0.01$", "phi": phi_exp_refined_factory(5, 7, 0.01)}, 1],
                [{"alg": "phi", "label": "$p=5, q=3$", "phi": phi_exp_refined_factory(5, 3, 0)}, 1],
                [{"alg": "phi", "label": "$p=7, q=5$", "phi": phi_exp_refined_factory(7, 5, 0)}, 1],
            ],
            exp_mode="multi-run",
            path="./Figures/Exp_Exp.pdf",
            clip=-16,
        ),
        "Poly": lambda: dict(
            model=model,
            max_iter=30000,
            param_list=[
                [{"alg": "phi", "label": "Poly(2)", "phi": phi_poly_factory(2), "step_include_d": True, "init_type": "escort_2"}, 0.01],
                [{"alg": "escort_normalized", "label": "Escort(4)", "p": 4}, 0.01],
                [{"alg": "phi", "label": "Poly(4)", "phi": phi_poly_factory(4), "step_include_d": True, "init_type": "escort_4"}, 0.01],
                [{"alg": "escort_normalized", "label": "Escort(6)", "p": 6}, 0.01],
                # "escort_6" (was the typo "excort_6"), matching escort_2 / escort_4.
                [{"alg": "phi", "label": "Poly(6)", "phi": phi_poly_factory(6), "step_include_d": True, "init_type": "escort_6"}, 0.01],
            ],
            exp_mode="multi-run",
            path="./Figures/Poly_Exp.pdf",
            clip=-10,
        ),
    }

    parser = argparse.ArgumentParser()
    # `choices` rejects unknown names instead of silently running nothing.
    parser.add_argument('--type', default='all', choices=['all', *experiments])
    args = parser.parse_args()

    S_size = 50
    A_size = 10
    gamma = .7

    seed = 42
    model = RandomMDP(S_size=S_size,
                      A_size=A_size,
                      gamma=gamma,
                      seed=seed)

    H, W = 5, 5
    board = generate_one_goal_board(H, W, random=False)

    grid_world = Grid_world(board,
                            gamma=0.9,
                            win_reward=1,
                            punish_reward=0)

    # Every experiment writes into ./Figures; make sure it exists.
    Path("./Figures").mkdir(parents=True, exist_ok=True)

    for name, make_kwargs in experiments.items():
        if args.type not in (name, "all"):
            continue
        # Settings shared by every experiment; an entry may override them.
        kwargs = {"metric": "rho", "noise": None, "seed": seed, **make_kwargs()}
        run(**kwargs)